{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linear Layer\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "Sometimes, Linear Layers are also called Dense Layers, like in the toolkit Keras.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What do linear layers do?\n",
    "\n",
    "A linear layer transforms a vector into another vector. For example, you can transform a vector `[1, 2, 3]` to `[1, 2, 3, 4]` with a linear layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## When to use linear layers?\n",
    "\n",
    "Use linear layers when you want to change a vector into another vector. This often happens when the target vector's shape is different from the vector at hand.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "```{note}\n",
    "Linear layers are often called linear transformation or linear mapping.\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## How does a linear layer work?\n",
    "\n",
    "There are two components in a linear layer. A weight $ W $, and a bias $ B $. If the input of a linear layer is a vector $ X $, then the output is $ W X + B $.\n",
    "\n",
    "If the linear layer transforms a vector of dimension $ N $ to dimension $ M $, then $ W $ is a $ M \\times N $ matrix, $ X $ is of dimension $ N $, $ B $ is of dimension $ M $."
   ]
  },
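  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked illustration of $ W X + B $, take $ N = 3 $ and $ M = 4 $. The values of $ W $ and $ B $ below are picked by hand for this example (an assumption, not learned parameters); they are chosen so that $ X = [1, 2, 3] $ maps to $ [1, 2, 3, 4] $:\n",
    "\n",
    "$$\n",
    "W X + B = \\begin{bmatrix} 1 & 0 & 0 \\\\ 0 & 1 & 0 \\\\ 0 & 0 & 1 \\\\ 1 & 1 & 0 \\end{bmatrix} \\begin{bmatrix} 1 \\\\ 2 \\\\ 3 \\end{bmatrix} + \\begin{bmatrix} 0 \\\\ 0 \\\\ 0 \\\\ 1 \\end{bmatrix} = \\begin{bmatrix} 1 \\\\ 2 \\\\ 3 \\\\ 3 \\end{bmatrix} + \\begin{bmatrix} 0 \\\\ 0 \\\\ 0 \\\\ 1 \\end{bmatrix} = \\begin{bmatrix} 1 \\\\ 2 \\\\ 3 \\\\ 4 \\end{bmatrix}\n",
    "$$"
   ]
  },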
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linear layers in code?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import Linear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "linear = Linear(3, 4)\n",
    "print(linear.weight.detach())\n",
    "print(linear.bias.detach())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You see, linear layers are just 2 matrices, weight and bias."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([1., 2., 3.])\n",
    "y1 = linear(x)\n",
    "y2 = linear.weight @ x + linear.bias\n",
    "print(y1)\n",
    "print(y2)\n",
    "print(y1 == y2)"
   ]
  },
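  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same computation can be sketched in plain Python, without PyTorch, to make the multiply-then-add step explicit. This is only a minimal sketch: `linear_forward` is a hypothetical helper, and `W` and `B` are hand-picked values (not learned parameters, and not what `Linear(3, 4)` initializes to)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical hand-picked parameters: a 4 x 3 weight and a bias of size 4.\n",
    "W = [[1.0, 0.0, 0.0],\n",
    "     [0.0, 1.0, 0.0],\n",
    "     [0.0, 0.0, 1.0],\n",
    "     [1.0, 1.0, 0.0]]\n",
    "B = [0.0, 0.0, 0.0, 1.0]\n",
    "\n",
    "def linear_forward(weight, bias, vec):\n",
    "    # Each output entry is the dot product of one weight row with the input,\n",
    "    # plus the matching bias entry.\n",
    "    return [sum(w * v for w, v in zip(row, vec)) + b\n",
    "            for row, b in zip(weight, bias)]\n",
    "\n",
    "y_plain = linear_forward(W, B, [1.0, 2.0, 3.0])\n",
    "print(y_plain)  # [1.0, 2.0, 3.0, 4.0]"
   ]
  },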
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All that a linear layer do is to `matmul` the input vector, then added by the bias. It's the linear algebra notation of $ WX+B $, with $ W $ the weight matrix, and $ B $ the bias vector."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
